# =============================================================================
#  Analysis 1: This example script demonstrates single-cell tracking analysis on an example video from the Cell Tracking Challenge dataset (http://celltrackingchallenge.net/).
# =============================================================================

"""
1. Load the video frames and create video.
"""
from skimage.exposure import rescale_intensity
import numpy as np 
from MOSES.Utility_Functions.file_io import detect_files
from skimage.io import imread

vid_folder = '../Videos/PhC-C2DH-U373/01'  
# load individal video frames saved as .tif files. 
video_files, video_fnames = detect_files(vid_folder, ext='.tif')
# read individual frames and concatenate frames after rescaling intensity 
vidstack = np.concatenate([rescale_intensity(imread(f))[None,:] for f in video_files], axis=0)

"""
2. Compute the Superpixel Tracks
"""
from MOSES.Optical_Flow_Tracking.superpixel_track import compute_grayscale_vid_superpixel_tracks_FB
import scipy.io as spio

n_spixels = 1000
# set the motion extraction parameters
opt_flow_params = {'pyr_scale':0.5, 'levels':7, 'winsize':25, 'iterations':3, 'poly_n':5, 'poly_sigma':1.2, 'flags':0}
# compute superpixel tracks
optflow, tracks_F, tracks_B = compute_grayscale_vid_superpixel_tracks_FB(vidstack, opt_flow_params, n_spixels=n_spixels)
# save the tracks 
spio.savemat('vid01_%d_spixels.mat' %(n_spixels), {'tracks_F':tracks_F, 
                                                   'tracks_B':tracks_B})

"""
3. Segment Individual Cells in initial Frame
"""
from skimage.filters import threshold_otsu, threshold_triangle
from skimage.measure import label
from skimage.morphology import remove_small_objects, binary_closing, disk
from scipy.ndimage.morphology import binary_fill_holes
import pylab as plt 

# first frame of video
frame0 = vidstack[0]

# basic quick cell segmentation by thresholding intensities
# a) determine intensity threshold
thresh = np.mean(frame0) + .5*np.std(frame0)
binary = frame0 >= thresh
# b) steps to refine the basic thresholding of a)
binary = remove_small_objects(binary, 200)
binary = binary_closing(binary, disk(5))
binary = binary_fill_holes(binary)
binary = remove_small_objects(binary, 1000)

# c) connected component analysis to label binary cell segmentation with integers
cells_frame0 = label(binary)

# =============================================================================
# Show the first frame of the video and the segmentation mask
# =============================================================================
# Raw first frame in grayscale.
# grid(False) actually hides the grid; grid('off') passes a truthy string,
# which turns the grid on in current matplotlib.
plt.figure()
plt.imshow(frame0, cmap='gray')
plt.grid(False)
plt.axis('off')
#plt.savefig('frame0.svg', bbox_inches='tight')
plt.show()

# Integer-labelled segmentation of the first frame, colour-coded by label.
plt.figure()
plt.imshow(cells_frame0)
plt.grid(False)
plt.axis('off')
#plt.savefig('frame0_segmented_cells.svg', bbox_inches='tight')
plt.show()

"""
4. Assign superpixel tracks to segmented cells
"""
# define a function 'assign_spixels' that can be called to assing superpixel tracks to integer labelled mask
def assign_spixels(spixeltracks, mask):
    
    uniq_regions = np.unique(mask)
    initial_pos = spixeltracks[:,0,:]
    
    bool_mask = []
    # starts from 1-index not 0-index as the first is assigned to the background.
    for uniq_region in uniq_regions[1:]:
        mask_region = mask == uniq_region
        bool_mask.append(mask_region[initial_pos[:,0], initial_pos[:,1]])
        
    bool_mask = np.vstack(bool_mask)

    return bool_mask   

# Boolean membership matrix: row i selects the forward tracks that start inside the i-th segmented cell.
cell_spixels = assign_spixels(spixeltracks=tracks_F, mask=cells_frame0)

"""
5. Compute the single track to describe the single cell motion.
"""
def find_characteristic_motion(tracks):
    
    # This function chooses the single longest superpixel track to summarise the single cell motion.
    disps_tracks = tracks[:, 1:, :] - tracks[:, :-1, :]
    disps_mag = np.sum(np.sqrt(disps_tracks[:,:,0]**2 + disps_tracks[:,:,1]**2), axis=1)
    rank = np.argsort(disps_mag)[::-1]

    return tracks[rank[0]]

# Summarise each segmented cell by the single longest superpixel track that
# starts inside it.
single_cell_tracks = []
for cell_id, cell_spixel in enumerate(cell_spixels):
    # forward tracks whose initial position lies inside cell `cell_id`
    cell_tracks = tracks_F[cell_spixel]
    if len(cell_tracks) == 0:
        # No superpixel was seeded inside this (small) cell, so there is no
        # track to choose from; skip it instead of crashing.
        print('Skipping segmented cell %d: no superpixel tracks start inside it.' % cell_id)
        continue
    # the single track that summarises the motion of all of this cell's tracks
    single_cell_tracks.append(find_characteristic_motion(cell_tracks))

if not single_cell_tracks:
    raise RuntimeError('No segmented cell contains the start of a superpixel track; check the segmentation in step 3.')

# stack into a [cell, frame, coord] array
single_cell_tracks = np.stack(single_cell_tracks, axis=0)

"""
6. Plot the segmented tracks 
"""
from MOSES.Visualisation_Tools.track_plotting import plot_tracks
import seaborn as sns 
cell_colors = sns.color_palette('Set1', n_colors=len(single_cell_tracks))

sns.palplot(cell_colors)

fig, ax = plt.subplots()
plt.title('Single Cell Tracks', fontsize=16)
ax.imshow(vidstack[0], cmap='gray')

for ii in range(len(single_cell_tracks)):
    single_cell_track = single_cell_tracks[ii]
    plot_tracks(single_cell_track[None,:], ax, color=cell_colors[ii], lw=3)

plt.grid('off')
plt.axis('off')
#fig.savefig('frame0_segmented_cell_tracks.svg', bbox_inches='tight')
plt.show()

# additionally plot all the extracted tracks without any segmentation for comparison
fig, ax = plt.subplots()
plt.title('Superpixel Tracks', fontsize=16)
plt.imshow(frame0, cmap='gray')
plot_tracks(tracks_F, ax=ax, color='r')
plt.show()

"""
7. Advanced Analysis: Analysing the motion saliency map to determine motion sources and sinks 
"""
from MOSES.Motion_Analysis.mesh_statistics_tools import compute_motion_saliency_map
from skimage.filters import gaussian

dist_thresh = 30
spixel_size = tracks_B[1,0,1]-tracks_B[1,0,0]

saliency_F, saliency_spatial_time_F = compute_motion_saliency_map(tracks_F, dist_thresh=dist_thresh, shape=frame0.shape, max_frame=None, filt=1, filt_size=spixel_size)
saliency_B, saliency_spatial_time_B = compute_motion_saliency_map(tracks_B, dist_thresh=dist_thresh, shape=frame0.shape, max_frame=None, filt=1, filt_size=spixel_size)
# Use Gaussian smoothing to create a more continuous heatmap
saliency_F_smooth = gaussian(saliency_F, spixel_size/2)
saliency_B_smooth = gaussian(saliency_B, spixel_size/2)

fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(15,15))
ax[0].set_title('Superpixel Tracks')
ax[0].imshow(frame0, cmap='gray')
plot_tracks(tracks_F, ax=ax[0], color='r'); ax[0].axis('off')
ax[1].set_title('Motion Saliency (Forward)')
ax[1].imshow(vidstack[-1], cmap='gray')
ax[1].imshow(saliency_F_smooth, cmap='coolwarm', alpha=0.7); ax[1].axis('off')
ax[2].set_title('Motion Saliency (Backward)')
ax[2].imshow(frame0, cmap='gray')
ax[2].imshow(saliency_B_smooth, cmap='coolwarm', alpha=0.7); ax[2].axis('off')
#fig.savefig('Example-motion-saliency-analysis.svg', dpi=300, bbox_inches='tight')
plt.show()
